% trainPrior: Learn prior distribution on illuminants
%
% trainPrior(lst,saveloc,true_l)
%
%   lst:     Cell array (Nx1) containing paths of pre-computed statistics files.
%   saveloc: Path of the training .mat file that was passed to train (which
%            should be called first). This function writes a new file
%            named [saveloc '.w.mat'] containing cov_t, alpha and Q.
%   true_l:  Nx3 array containing the ground-truth illuminant for each
%            file. If this parameter is omitted, all illuminants
%            will be assumed to be [1 1 1].
function trainPrior(lst,saveloc,true_l)
  % Fit the prior matrix Q and the prior strength alpha, then write
  % cov_t, alpha and Q to [saveloc '.w.mat'].

  % Documented default: without ground truth every illuminant is [1 1 1].
  if nargin < 3 || isempty(true_l)
    true_l = ones(length(lst),3);
  end

  % Load into a struct so variables stored in the training file cannot
  % silently overwrite this function's locals (lst, true_l, ...).
  S = load('-mat',saveloc);
  cov_t = [];
  if isfield(S,'cov_t')
    cov_t = S.cov_t;
  end
  outfile = [saveloc '.w.mat'];

  % Normalise each ground-truth illuminant (row) to unit length.
  nrm = sqrt(sum(true_l.^2,2));
  true_l = true_l ./ repmat(nrm,[1 3]);
  % Q = inverse of the mean outer product of the inverse illuminants.
  w = 1./true_l;
  Q = inv((w' * w)/size(w,1));

  N = length(lst);
  % For large training sets only every 10th file is processed.
  step = 10;
  if N < 200
    step = 1;
  end
  used = 1:step:N;

  A = zeros(3,3,N);
  for i = used
    fprintf('\r File %d of %d         ',i,N);
    fouts = loadpc(lst{i});

    [Ai,NK] = getA(fouts,cov_t,true_l(i,:));
    ell = atoell(Ai);
    A(:,:,i) = Ai*NK / sum(ell.^2);
  end
  fprintf('\n');

  % Hierarchical grid search over log10(alpha), starting on [2, 8]:
  % evaluate 4 points, shrink the interval around the best, repeat 4 times.
  lft = 2; rght = 8;
  for hier = 1:4
    alphas = linspace(lft,rght,4);
    score = zeros(size(alphas));
    for k = 1:length(alphas)
      % Only score files whose A(:,:,j) was actually computed above;
      % the remaining slices are still all-zero.
      for j = used
        score(k) = score(k) + ferr(true_l(j,:),atoell(A(:,:,j)+Q*10^alphas(k)));
      end
    end
    [~,best] = min(score);
    alpha = alphas(best);
    delta = 0.8*(alphas(2)-alphas(1));
    lft = alpha - delta;
    rght = alpha + delta;
  end
  alpha = 10^alpha;

  save('-mat',outfile,'cov_t','alpha','Q');
    
function g = ferr(lo,l)
  % ferr: angular error, in degrees, between two vectors.
  %
  %   lo, l: vectors of equal length; only their directions matter, since
  %          both are normalised to unit length first.
  %   g:     angle between them in degrees, in [0, 180].
  lo = lo(:) / sqrt(sum(lo.^2));
  l = l(:) / sqrt(sum(l.^2));
  % Rounding can push the dot product of (near-)parallel unit vectors
  % slightly outside [-1, 1], where acosd returns a complex value.
  c = min(max(sum(l.*lo),-1),1);
  g = acosd(c);

function [A,NK] = getA(fouts,si,m)
  % getA: accumulate a weighted, si-masked scatter matrix over blocks.
  %
  %   fouts: cell array; fouts{k} is an (n_k x 3) matrix of rows.
  %   si:    array whose slice si(:,:,k) is paired with fouts{k}.
  %   m:     3-vector; columns of each block are scaled by 1./m before the
  %          per-row weight is computed.
  %   A:     3x3 matrix, (4/NK) * sum_k (D_k'*D_k) .* si(:,:,k), where each
  %          row of D_k is a row of fouts{k} divided by max(|q|^(1/4), 1e-24)
  %          and q is the quadratic form of the scaled row under si(:,:,k).
  %   NK:    total number of rows over all blocks (NK == 0 yields NaN in A).
  nblk = length(fouts);
  A = zeros(3,3);
  NK = 0;
  invm = diag(1./m);
  for k = 1:nblk
    rows = fouts{k};
    Sk = si(:,:,k);
    NK = NK + size(rows,1);
    % Per-row quadratic form of the m-scaled rows under Sk.
    scaled = rows * invm;
    q = sum(scaled .* (scaled * Sk), 2);
    % Per-row normaliser |q|^(1/4), floored to avoid dividing by zero.
    nrm = max(abs(q).^(0.25), 10^-24);
    rows = rows ./ repmat(nrm,[1 3]);
    A = A + (rows'*rows) .* Sk;
  end
  A = 4*A/NK;
  
function ell = atoell(A)
  % atoell: iteratively compute a 3-vector ell from a 3x3 matrix A.
  %
  %   ell = atoell(A)
  %
  %   A:   3x3 matrix (the inner loop is hard-coded to 3 entries). Its
  %        diagonal is assumed positive, since sqrt/log are applied below;
  %        TODO confirm callers guarantee this.
  %   ell: column vector, initialised to sqrt(diag(A)) and refined by
  %        coordinate-wise updates; each ell(i) uses the most recently
  %        updated values of the other entries.
  ell = sqrt(diag(A));
  C = inf;            % previous value of the convergence criterion
  for it = 1:10       % at most 10 sweeps
    for i = 1:3
      % sj = sum over j ~= i of A(j,i)/ell(j)
      sj = sum(A(:,i) ./ ell(:)) - A(i,i)/ell(i);
      % Larger root of x^2 - sj*x - A(i,i) = 0
      ell(i) = 0.5*(sj + sqrt(sj^2 + 4*A(i,i)));
    end;
    % Stop once the criterion fails to decrease by a relative 1e-4.
    % NOTE(review): the relative test assumes C > 0; for negative C the
    % threshold (1-1e-4)*C lies above C.
    Cn = sum(log(ell)) + 0.5*ell'*A*ell;
    if Cn > (1-10^-4)*C
      break;
    end;
    C = Cn;
  end;
